package pers.apricot.service

import com.coveo.saml.SamlClient
import com.coveo.saml.SamlResponse
import groovy.transform.CompileStatic
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.http.HttpMethod
import org.springframework.stereotype.Service

import javax.annotation.PostConstruct
import java.nio.charset.StandardCharsets

/**
 * Thin wrapper around a Coveo {@link SamlClient} built from {@link SamlProperties}.
 *
 * <p>The client is only created when {@code samlProperties.enabled} is truthy. Otherwise
 * {@link #samlClient} stays {@code null} and the request/response methods throw
 * {@link IllegalStateException}.
 */
@Service
@CompileStatic
class SamlService {

    @Autowired
    SamlProperties samlProperties

    /** Configured SAML client; {@code null} unless {@link #init()} ran with SAML enabled. */
    SamlClient samlClient

    /**
     * Builds the SAML client from the identity-provider metadata resource and, when both are
     * configured, registers the service-provider certificate and private key.
     *
     * @throws IllegalArgumentException if SAML is enabled but {@code metadataLocation} is not
     *         configured or cannot be found, or if a configured key resource cannot be found
     */
    @PostConstruct
    void init() {
        if (!samlProperties.enabled) {
            return
        }
        if (!samlProperties.metadataLocation) {
            throw new IllegalArgumentException("saml miss config metadataLocation")
        }
        // Resolved relative to this class unless the location starts with '/'.
        InputStream inputStream = SamlService.getResourceAsStream(samlProperties.metadataLocation)
        if (!inputStream) {
            throw new IllegalArgumentException("saml error config metadataLocation")
        }
        Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)
        try {
            samlClient = SamlClient.fromMetadata(
                    samlProperties.relyingPartyIdentifier, samlProperties.assertionConsumerServiceUrl,
                    reader)
        } finally {
            // Release the metadata stream once it has been consumed (closing the reader
            // also closes the underlying InputStream).
            reader.close()
        }
        if (samlProperties.publicKeyCrt && samlProperties.privateKeyPk8) {
            samlClient.setSPKeys(resolveResourcePath(samlProperties.publicKeyCrt),
                    resolveResourcePath(samlProperties.privateKeyPk8))
        }
    }

    /**
     * Resolves a classpath resource to the path string handed to {@link SamlClient#setSPKeys}.
     *
     * @param location resource location, resolved the same way as {@code Class.getResource}
     * @return a filesystem path for {@code file:} resources, otherwise {@code URL.getFile()}
     * @throws IllegalArgumentException if the resource does not exist
     */
    private static String resolveResourcePath(String location) {
        URL url = SamlService.class.getResource(location)
        if (url == null) {
            throw new IllegalArgumentException("saml key resource not found: " + location)
        }
        // URL.getFile() keeps percent-escapes (e.g. %20 for spaces), which a file-based
        // loader cannot open; converting through URI decodes them for plain file: URLs.
        if (url.protocol == 'file') {
            return new File(url.toURI()).path
        }
        return url.file
    }

    /**
     * Returns the configured client, failing with a clear message instead of an NPE when SAML
     * is disabled or {@link #init()} has not run.
     */
    private SamlClient requireClient() {
        if (samlClient == null) {
            throw new IllegalStateException("saml is not enabled or not initialized")
        }
        return samlClient
    }

    /**
     * Builds the encoded SAML authentication request for the identity provider.
     *
     * @return the SAML request produced by {@link SamlClient#getSamlRequest()}
     * @throws IllegalStateException if SAML is not enabled
     */
    String buildSamlRequest() {
        return requireClient().getSamlRequest()
    }

    /**
     * Misspelled alias of {@link #buildSamlRequest()}, kept so existing callers keep working.
     *
     * @return the SAML request produced by {@link SamlClient#getSamlRequest()}
     * @throws IllegalStateException if SAML is not enabled
     */
    String buildSmalRequest() {
        return buildSamlRequest()
    }

    /**
     * Decodes and validates a SAML response received via HTTP POST and returns its NameID
     * (the user name).
     *
     * @param samlResponse the encoded SAML response posted back by the identity provider
     * @return the NameID carried by the validated response
     * @throws IllegalStateException if SAML is not enabled
     */
    String parseSamlResponse(String samlResponse) {
        SamlResponse response = requireClient().decodeAndValidateSamlResponse(samlResponse, HttpMethod.POST.toString())
        return response.getNameID()
    }
}
